"""

    PaCo Algorithm
    [Co-Clustering Algorithm]

    Literature:
        Michail Vlachos, Francesco Fusco, Charalambos Mavroforakis, Anastasios Kyrillidis, and
        Vassilios G. Vassiliadis:
        Improving Co-Cluster Quality with Application to Product Recommendations. 2014.
        http://dl.acm.org/citation.cfm?id=2661980

"""

# © 2018. Case Recommender (MIT License)

import itertools
import random

import numpy as np
from scipy.spatial.distance import squareform, pdist
from sklearn.cluster import KMeans

from caserec.utils.process_data import ReadFile

__author__ = 'removed for double blind review'


class PaCo(object):
    def __init__(self, train_file, k_row=None, l_col=None, density_low=0.008, as_binary=True,
                 sep='\t', random_seed=None):
        """
        PaCo: EntroPy Anomalies in Co-Clustering

        Usage::

            >> PaCo(train, 10, 15).compute()

        :param train_file: File which contains the train set. This file needs to have at least 3 columns
        (user item feedback_value).
        :type train_file: str

        :param k_row: Number of row (user) clusters requested from k-means. When None, the integer part
        of the square root of the number of users is used.
        :type k_row: int, default None

        :param l_col: Number of column (item) clusters requested from k-means. When None, the integer
        part of the square root of the number of items is used.
        :type l_col: int, default None

        :param density_low: Threshold below which a bi-cluster density is set to zero after a merge
        :type density_low: float, default 0.008

        :param as_binary: Forwarded to ReadFile; when True the explicit feedback is expected to be
        converted to binary values
        :type as_binary: bool, default True

        :param sep: Delimiter for input files
        :type sep: str, default '\t'

        :param random_seed: Seed applied to both numpy.random and random for reproducible experiments
        :type random_seed: int, default None

        """

        reader = ReadFile(train_file, as_binary=as_binary, sep=sep)
        self.train_set = reader.read()
        self.density_low = density_low
        self.users = self.train_set['users']
        self.items = self.train_set['items']

        # Seed both generators so every random source used by the class is locked
        if random_seed is not None:
            np.random.seed(random_seed)
            random.seed(random_seed)

        # Fall back to sqrt(n) clusters on each axis when no size is given
        self.k_row = int(np.sqrt(len(self.users))) if k_row is None else k_row
        self.l_col = int(np.sqrt(len(self.items))) if l_col is None else l_col

        # One (initially empty) list of member indices per row / column cluster
        self.list_row = [[] for _ in range(self.k_row)]
        self.list_col = [[] for _ in range(self.l_col)]

        # Bidirectional maps between external labels and matrix indices
        self.item_to_item_id = {item: index for index, item in enumerate(self.items)}
        self.item_id_to_item = dict(enumerate(self.items))
        self.user_to_user_id = {user: index for index, user in enumerate(self.users)}
        self.user_id_to_user = dict(enumerate(self.users))

        self.count_total, self.count_ones = [], []
        self.density = None
        self.delta_entropy = []
        self.matrix = []

        self.create_matrix()

    def create_matrix(self):
        """
        Build the dense feedback matrix (users in rows, items in columns).

        Row and column positions come from self.user_to_user_id and self.item_to_item_id;
        every cell without feedback in the train set stays at zero.

        """

        n_users, n_items = len(self.users), len(self.items)
        self.matrix = np.zeros((n_users, n_items))

        feedback = self.train_set['feedback']
        for user in self.train_set['users']:
            user_feedback = feedback[user]
            row = self.user_to_user_id[user]
            for item in user_feedback:
                self.matrix[row, self.item_to_item_id[item]] = user_feedback[item]

    def run_kmeans(self):
        """
        Method to apply kmeans++ to rows and cols

        The rows of the feedback matrix are clustered into self.k_row groups and its columns into
        self.l_col groups. The resulting memberships are stored in self.list_row and self.list_col as
        one list of matrix indices per cluster label.

        The groupings are rebuilt from scratch on every call, so calling the method again never
        duplicates indices left over from a previous run.

        """

        row_labels = KMeans(n_clusters=self.k_row, init='k-means++').fit(self.matrix).labels_
        col_labels = KMeans(n_clusters=self.l_col, init='k-means++').fit(self.matrix.T).labels_

        # Map inverse index: cluster label -> list of member indices
        list_row = [[] for _ in range(self.k_row)]
        for row_id, label in enumerate(row_labels):
            list_row[label].append(row_id)

        list_col = [[] for _ in range(self.l_col)]
        for col_id, label in enumerate(col_labels):
            list_col[label].append(col_id)

        self.list_row, self.list_col = list_row, list_col

    def count_information(self):
        """
        Method to count the number of interaction in each bi-cluster

        """

        for label_row in range(self.k_row):
            for label_col in range(self.l_col):
                count_local = 0

                for pair in itertools.product(self.list_row[label_row], self.list_col[label_col]):
                    if self.matrix[pair[0]][pair[1]] != 0:
                        count_local += 1

                self.count_total.append(len(self.list_row[label_row]) * len(self.list_col[label_col]))
                self.count_ones.append(count_local)

        self.update_information(first_iteration=True)

    def update_information(self, first_iteration=False):
        """
        Method to update information after a epoach

        :param first_iteration: if True calculate self.count_total and self.count_ones
        :type first_iteration: bool, default False

        """

        if first_iteration:
            self.count_total = np.matrix(self.count_total).reshape((self.k_row, self.l_col))
            self.count_ones = np.matrix(self.count_ones).reshape((self.k_row, self.l_col))
            self.density = np.matrix(np.divide(self.count_ones, self.count_total))
            # self.density = np.matrix(np.divide(self.count_ones, self.count_total)).reshape((self.k_row, self.l_col))
        else:
            self.density = np.matrix(np.divide(self.count_ones, self.count_total))
            self.density[self.density < self.density_low] = .0

    def calculate_entropy(self):
        """
        Method to calculate the entropy in each epoach

        """

        total_density = self.density.sum()
        probability = np.divide(self.density, total_density)

        sum_pi = 0
        for pi in probability.flat:
            sum_pi += 0 if pi == 0 else pi * np.log2(pi)

        return (-sum_pi) / np.log2(probability.size)

    @staticmethod
    def return_min_value(matrix):
        """
        Method to find the min value in a matrix

        """

        min_value = (float('inf'), (0, 0))
        for i in range(len(matrix)):
            for j in range(i):
                if matrix[i][j] < min_value[0]:
                    min_value = (matrix[i][j], (i, j))

        return min_value

    def merge(self, min_value_row, min_value_col):
        """
        Method to combine two bi-cluters based on min value (Merge on line or columns)

        """

        if min_value_row[0] > min_value_col[0]:

            # merge of columns
            pair = min_value_col[1]

            new_set_col = self.list_col[pair[0]].copy() + self.list_col[pair[1]].copy()
            self.list_col = list(np.delete(self.list_col, [pair[0], pair[1]], axis=0))
            self.list_col.append(new_set_col)

            # update count total based on columns
            new_count_total = self.count_total[:, pair[0]] + self.count_total[:, pair[1]]
            self.count_total = np.delete(self.count_total, (pair[0], pair[1]), axis=1)
            self.count_total = np.insert(self.count_total, self.count_total.shape[1], new_count_total.T, axis=1)

            # update count ones based on columns
            new_count_ones = self.count_ones[:, pair[0]] + self.count_ones[:, pair[1]]
            self.count_ones = np.delete(self.count_ones, (pair[0], pair[1]), axis=1)
            self.count_ones = np.insert(self.count_ones, self.count_ones.shape[1], new_count_ones.T, axis=1)

        else:
            # merge of rows
            pair = min_value_row[1]

            new_set_row = self.list_row[pair[0]].copy() + self.list_row[pair[1]].copy()
            self.list_row = list(np.delete(self.list_row, [pair[0], pair[1]], axis=0))
            self.list_row.append(new_set_row)

            # update count total based on rows
            new_count_total = self.count_total[pair[0], :] + self.count_total[pair[1], :]
            self.count_total = np.delete(self.count_total, (pair[0], pair[1]), axis=0)
            self.count_total = np.insert(self.count_total, self.count_total.shape[0], new_count_total, axis=0)

            # update count ones based on rows
            new_count_ones = self.count_ones[pair[0], :] + self.count_ones[pair[1], :]
            self.count_ones = np.delete(self.count_ones, (pair[0], pair[1]), axis=0)
            self.count_ones = np.insert(self.count_ones, self.count_ones.shape[0], new_count_ones, axis=0)

        self.update_information()

    def fit(self):
        """
        This method performs iterations of combination of bi-clusters.

        """

        count_epoch = 0
        criteria = True
        # 1st step: run k-means
        self.run_kmeans()
        # 2st step: collect information (only one time)
        self.count_information()

        entropy0 = self.calculate_entropy()

        # 3st step: training the algorithm
        while criteria:
            old_density, old_list_row, old_list_col = self.density.copy(), self.list_row.copy(), self.list_col.copy()
            distance_rows = np.divide(np.float32(squareform(pdist(self.density, 'euclidean'))), self.density.shape[1])
            distance_cols = np.divide(np.float32(squareform(pdist(self.density.T, 'euclidean'))), self.density.shape[0])
            min_row = self.return_min_value(distance_rows)
            min_col = self.return_min_value(distance_cols)

            self.merge(min_row, min_col)

            # Check the number os bi-clusters
            if len(self.list_row) == 1 and len(self.list_col) == 1:
                break

            entropy = self.calculate_entropy()
            dif_entropy = entropy - entropy0
            self.delta_entropy.append(dif_entropy)
            mean_range, std_range = np.mean(self.delta_entropy), np.std(self.delta_entropy)

            if not (mean_range - 3 * std_range <= dif_entropy <= mean_range + 3 * std_range):
                self.density, self.list_row, self.list_col = old_density, old_list_row, old_list_col
                criteria = False
            else:
                entropy0 = entropy
                count_epoch += 1

        return entropy0

    def filter_relevant_bi_groups(self):
        """
        This method is responsible to filter the bi-groups removing the ones with lower density,
        leaving the minimum to recommend to every user

        """

        filtered_densities = self.density.copy()
        filtered_densities[filtered_densities == 1] = np.nan

        first_run = True
        old = filtered_densities.copy()
        while True:
            for line in filtered_densities:
                print(np.nansum(line))
                if np.nansum(line) == 0:
                    if first_run:
                        return np.logical_not(np.isnan(self.density))
                    else:
                        return np.logical_not(np.isnan(old))

            old = filtered_densities.copy()

            # replace min in filteredDensities value per nan
            index = np.nanargmin(filtered_densities)
            filtered_densities.flat[index] = np.nan

            first_run = False

    def compute(self, verbose=True):
        """
        Method to run recommender algorithm

        :param verbose: Print recommender and database information
        :type verbose: bool, default True

        """

        if verbose:
            print("[Case Recommender: %s]\n" % 'PaCo: Co-Clustering Algorithm')
            print("Final entropy::", self.fit())
            print("K rows:: ", len(self.list_row), "and L columns:: ", len(self.list_col))
            print("Number of bi-groups:: ", len(self.list_row) * len(self.list_col))
            print("Number of bi-groups needing recommendations:: ", self.density[np.logical_and(
                self.density != 1, self.density != 0)].size)
        else:
            self.fit()
